{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Randomized Search with scikit-learn"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "from sklearn import datasets\n",
    "from sklearn.model_selection import train_test_split, cross_val_score\n",
    "\n",
    "# Load iris and keep only the feature columns from index 2 onward.\n",
    "iris = datasets.load_iris()\n",
    "X = iris.data[:, 2:]\n",
    "y = iris.target\n",
    "\n",
    "# Class-stratified train/test split with a fixed seed so it is repeatable.\n",
    "X_train, X_test, y_train, y_test = train_test_split(\n",
    "    X, y, stratify=y, random_state=7\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "from sklearn.neighbors import KNeighborsClassifier\n",
    "\n",
    "# Estimator whose n_neighbors hyper-parameter is tuned.\n",
    "knn_clf = KNeighborsClassifier()\n",
    "\n",
    "# Candidate values 3 through 8 (range's upper bound is exclusive).\n",
    "param_dist = {'n_neighbors': list(range(3, 9))}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {
    "collapsed": false
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "RandomizedSearchCV(cv=10, error_score='raise',\n",
       "          estimator=KNeighborsClassifier(algorithm='auto', leaf_size=30, metric='minkowski',\n",
       "           metric_params=None, n_jobs=1, n_neighbors=5, p=2,\n",
       "           weights='uniform'),\n",
       "          fit_params=None, iid=True, n_iter=6, n_jobs=1,\n",
       "          param_distributions={'n_neighbors': [3, 4, 5, 6, 7, 8]},\n",
       "          pre_dispatch='2*n_jobs', random_state=None, refit=True,\n",
       "          return_train_score=True, scoring=None, verbose=0)"
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "from sklearn.model_selection import RandomizedSearchCV\n",
    "\n",
    "# 10-fold CV; n_iter=6 equals the number of candidate values in param_dist.\n",
    "rs = RandomizedSearchCV(knn_clf, param_dist, n_iter=6, cv=10)\n",
    "rs.fit(X_train, y_train)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {
    "collapsed": false
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{'n_neighbors': 3}"
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Best parameter setting found by the randomized search.\n",
    "rs.best_params_"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {
    "collapsed": false
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[({'n_neighbors': 3}, 0.9553571428571429),\n",
       " ({'n_neighbors': 4}, 0.9375),\n",
       " ({'n_neighbors': 5}, 0.9553571428571429),\n",
       " ({'n_neighbors': 6}, 0.9553571428571429),\n",
       " ({'n_neighbors': 7}, 0.9553571428571429),\n",
       " ({'n_neighbors': 8}, 0.9553571428571429)]"
      ]
     },
     "execution_count": 6,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Pair each sampled parameter setting with its mean cross-validated test score.\n",
    "# list() materializes the pairs: on Python 3 a bare zip() is a lazy iterator and\n",
    "# the cell would display only '<zip object ...>' instead of the table.\n",
    "list(zip(rs.cv_results_['params'], rs.cv_results_['mean_test_score']))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "# Widen the search space to the 47 candidates 3..49 but sample only 15 of them.\n",
    "# random_state pins which candidates are drawn so re-running the cell is repeatable.\n",
    "param_dist = {'n_neighbors': list(range(3, 50))}\n",
    "rs = RandomizedSearchCV(knn_clf, param_dist, n_iter=15, cv=10, random_state=7)\n",
    "rs.fit(X_train, y_train)\n",
    "rs.best_params_"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {
    "collapsed": false
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "1 loop, best of 3: 694 ms per loop\n"
     ]
    }
   ],
   "source": [
    "# Time a full refit of the randomized search currently bound to rs.\n",
    "%timeit rs.fit(X_train, y_train)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {
    "collapsed": false
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "1 loop, best of 3: 2.17 s per loop\n"
     ]
    }
   ],
   "source": [
    "from sklearn.model_selection import GridSearchCV\n",
    "\n",
    "# Exhaustive search over the same 3..49 range, timed for comparison with the\n",
    "# randomized search.\n",
    "param_grid = {'n_neighbors': list(range(3, 50))}\n",
    "gs = GridSearchCV(knn_clf, param_grid=param_grid, cv=10)\n",
    "%timeit gs.fit(X_train, y_train)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {
    "collapsed": false
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{'n_neighbors': 3}"
      ]
     },
     "execution_count": 14,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Best parameter setting found by the exhaustive grid search.\n",
    "gs.best_params_"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {
    "collapsed": false
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[({'n_neighbors': 3}, 0.9553571428571429),\n",
       " ({'n_neighbors': 4}, 0.9375),\n",
       " ({'n_neighbors': 5}, 0.9553571428571429),\n",
       " ({'n_neighbors': 6}, 0.9553571428571429),\n",
       " ({'n_neighbors': 7}, 0.9553571428571429),\n",
       " ({'n_neighbors': 8}, 0.9553571428571429),\n",
       " ({'n_neighbors': 9}, 0.9553571428571429),\n",
       " ({'n_neighbors': 10}, 0.9464285714285714),\n",
       " ({'n_neighbors': 11}, 0.9553571428571429),\n",
       " ({'n_neighbors': 12}, 0.9553571428571429),\n",
       " ({'n_neighbors': 13}, 0.9553571428571429),\n",
       " ({'n_neighbors': 14}, 0.9553571428571429),\n",
       " ({'n_neighbors': 15}, 0.9553571428571429),\n",
       " ({'n_neighbors': 16}, 0.9553571428571429),\n",
       " ({'n_neighbors': 17}, 0.9464285714285714),\n",
       " ({'n_neighbors': 18}, 0.9464285714285714),\n",
       " ({'n_neighbors': 19}, 0.9375),\n",
       " ({'n_neighbors': 20}, 0.9464285714285714),\n",
       " ({'n_neighbors': 21}, 0.9464285714285714),\n",
       " ({'n_neighbors': 22}, 0.9464285714285714),\n",
       " ({'n_neighbors': 23}, 0.9464285714285714),\n",
       " ({'n_neighbors': 24}, 0.9464285714285714),\n",
       " ({'n_neighbors': 25}, 0.9464285714285714),\n",
       " ({'n_neighbors': 26}, 0.9553571428571429),\n",
       " ({'n_neighbors': 27}, 0.9553571428571429),\n",
       " ({'n_neighbors': 28}, 0.9464285714285714),\n",
       " ({'n_neighbors': 29}, 0.9285714285714286),\n",
       " ({'n_neighbors': 30}, 0.9464285714285714),\n",
       " ({'n_neighbors': 31}, 0.9375),\n",
       " ({'n_neighbors': 32}, 0.9375),\n",
       " ({'n_neighbors': 33}, 0.9375),\n",
       " ({'n_neighbors': 34}, 0.9375),\n",
       " ({'n_neighbors': 35}, 0.9375),\n",
       " ({'n_neighbors': 36}, 0.9375),\n",
       " ({'n_neighbors': 37}, 0.9375),\n",
       " ({'n_neighbors': 38}, 0.9464285714285714),\n",
       " ({'n_neighbors': 39}, 0.9375),\n",
       " ({'n_neighbors': 40}, 0.9375),\n",
       " ({'n_neighbors': 41}, 0.9375),\n",
       " ({'n_neighbors': 42}, 0.9375),\n",
       " ({'n_neighbors': 43}, 0.9375),\n",
       " ({'n_neighbors': 44}, 0.9375),\n",
       " ({'n_neighbors': 45}, 0.9464285714285714),\n",
       " ({'n_neighbors': 46}, 0.9375),\n",
       " ({'n_neighbors': 47}, 0.9464285714285714),\n",
       " ({'n_neighbors': 48}, 0.9464285714285714),\n",
       " ({'n_neighbors': 49}, 0.9464285714285714)]"
      ]
     },
     "execution_count": 15,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Pair every grid point with its mean cross-validated test score.\n",
    "# list() materializes the pairs: on Python 3 a bare zip() is a lazy iterator and\n",
    "# the cell would display only '<zip object ...>' instead of the table.\n",
    "list(zip(gs.cv_results_['params'], gs.cv_results_['mean_test_score']))"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 2",
   "language": "python",
   "name": "python2"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 2
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython2",
   "version": "2.7.11"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 0
}
